import torch
from torch import Tensor
from torch.nn import Module


def test(model: Module, x: Tensor, y: Tensor) -> float:
    """Evaluate ``model`` on one batch and report its classification accuracy.

    The forward pass runs in eval mode with autograd disabled. Afterwards the
    previous train/eval flag of every submodule is restored, so calling this
    helper does not change the model's mode, and BatchNorm running statistics
    are not updated by the evaluation pass. The raw model output and the
    accuracy are printed; the accuracy is also returned.

    Args:
        model: Classifier; its output is expected to hold class scores along
            dim 1 (the prediction is the arg-max over that dim).
        x: Input batch, passed straight to ``model``.
        y: Integer class labels, one per sample. Presumably shape ``(N,)``;
            verify against callers.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, as a
        Python float, or ``nan`` when ``y`` is empty.
    """
    # Remember each submodule's own flag so mixed train/eval setups survive.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()  # disable dropout; stop BatchNorm from updating running stats
    try:
        with torch.no_grad():  # inference only: don't build an autograd graph
            predict_y = model(x)
    finally:
        for module, mode in saved_modes:
            module.training = mode
    print(predict_y)

    predict_y_pos = predict_y.argmax(dim=1)
    correct = predict_y_pos.eq(y).sum().item()
    total = y.shape[0]
    # Guard the empty batch explicitly instead of dividing by zero.
    accuracy = correct / total if total else float("nan")
    print('预测成功率：', accuracy)  # label reads "prediction success rate:"
    return accuracy


# The name matches pytest's default ``test`` prefix, but this is a helper, not
# a test case (its arguments are not fixtures), so opt out of collection.
test.__test__ = False
